#!/usr/bin/env python3
import sys, json, pathlib
import numpy as np, pandas as pd
from astropy.io import fits
from astropy.cosmology import FlatLambdaCDM

# Flat ΛCDM cosmology (astropy; H0 in km/s/Mpc). Not referenced by main() in this script.
COSMO = FlatLambdaCDM(Om0=0.3, H0=70)
# Pixel scale (arcsec, per the name); unused unless size columns appear in future.
PIX_ARCSEC = 0.214
# Edges of the galaxy-size (R_G) bins, in kpc.
RG_EDGES = [5.0, 7.5, 10.0, 12.5, 15.0]
# Edges of the stellar-mass bins, in log10 M*.
MS_EDGES = [10.2, 10.5, 10.8, 11.1]
# Centre of each R_G bin (mean of consecutive edges).
RG_MIDS = [0.5 * (lo + hi) for lo, hi in zip(RG_EDGES[:-1], RG_EDGES[1:])]

def cm(names):
    """Map each name, lower-cased, to its original spelling.

    Names are passed through ``str()``. If two names differ only in case, the
    later one wins.
    """
    lookup = {}
    for name in names:
        text = str(name)
        lookup[text.lower()] = text
    return lookup
def pick(m, opts):
    """Return ``m[o.lower()]`` for the first option *o* whose lower-case form is a key of *m*.

    *m* is expected to be a lower-case-name -> real-name mapping (as built by
    ``cm``). Returns None when no option matches.
    """
    matches = (m[opt.lower()] for opt in opts if opt.lower() in m)
    return next(matches, None)
def label(edges, i):
    """Return the text label of bin *i*: its lower and upper edge joined by an en dash."""
    lower, upper = edges[i], edges[i + 1]
    return f"{lower}–{upper}"

def _native(values):
    """Return a copy of a FITS table column as a plain, native-endian numpy array.

    FITS binary tables are stored big-endian, and some pandas versions refuse to
    hash (merge/groupby) non-native buffers, so numeric columns are converted to
    native byte order. Byte-string columns (e.g. a text ``ID``) are decoded to
    ``str`` with trailing FITS padding removed. This lets the join key compare
    cleanly and reach the CSV as text rather than as a ``b'...'`` repr.
    """
    arr = np.array(values)  # np.array copies, so nothing keeps pointing into the FITS mmap
    if arr.dtype.kind == "S":
        return np.char.rstrip(np.char.decode(arr, "utf-8"))
    if arr.dtype.kind == "U":
        return np.char.rstrip(arr)
    if not arr.dtype.isnative:
        arr = arr.astype(arr.dtype.newbyteorder("="))
    return arr


def _proxy_rg_index(mags, groups, n_bins):
    """Return a 0-based, near-equal-count size-bin index for each magnitude within its group.

    *mags* must be NaN-free and index-aligned with *groups*. Within each group,
    rows are ranked faintest-first (largest magnitude first; ties broken by row
    order) and split into ``n_bins`` bins of as-equal-as-possible size. Index 0
    holds the faintest objects and ``n_bins - 1`` the brightest, which encodes
    the proxy assumption that brighter galaxies tend to be larger.
    """
    if len(mags) == 0:
        return np.empty(0, dtype=np.int64)
    by_group = mags.groupby(groups, sort=False)
    # rank() is 1-based; shift to 0-based so floor(rank0 * n_bins / size) splits evenly.
    rank0 = by_group.rank(method="first", ascending=False).to_numpy(dtype=np.int64) - 1
    size = by_group.transform("count").to_numpy(dtype=np.int64)
    return (rank0 * n_bins) // size


def _load_tables(bright_path, lephare_path):
    """Read HDU 1 of the Bright and LePhare FITS files into DataFrames sharing ``lens_id``.

    Returns ``(Bd, Ld)``. Writes a message to stderr and exits with status 1 when a
    required column is missing. Both FITS files are closed before returning.
    """
    with fits.open(bright_path, memmap=True) as Hb, fits.open(lephare_path, memmap=True) as Hl:
        B, L = Hb[1], Hl[1]
        Bn, Ln = list(B.columns.names), list(L.columns.names)
        bm, lm = cm(Bn), cm(Ln)

        # RA/DEC and photo-z from Bright (zphot_ANNz2 preferred); mass from LePhare.
        raB = pick(bm, ["RAJ2000"])
        deB = pick(bm, ["DECJ2000"])
        zB = pick(bm, ["zphot_ANNz2", "ZPHOT", "PHOTOZ"])
        if not (raB and deB and zB):
            sys.stderr.write(f"Missing RA/DEC/z in Bright. Got: {Bn}\n")
            sys.exit(1)

        # Join key
        kB = pick(bm, ["ID"])
        kL = pick(lm, ["ID"])
        if not (kB and kL):
            sys.stderr.write("No common ID between Bright and LePhare.\n")
            sys.exit(1)

        # Mass (log10 M*)
        mL = pick(lm, ["MASS_MED", "MASS_BEST", "LOGMASS", "LOGMSTAR"])
        if not mL:
            sys.stderr.write(f"No stellar-mass column in LePhare. Columns: {Ln}\n")
            sys.exit(1)

        # The R_G proxy ranks on this magnitude and cannot be computed without it,
        # so stop here instead of writing a meaningless size assignment.
        magB = pick(bm, ["MAG_AUTO_CALIB"])
        if not magB:
            sys.stderr.write(f"No MAG_AUTO_CALIB in Bright (required for the R_G proxy). Got: {Bn}\n")
            sys.exit(1)
        zL = pick(lm, ["REDSHIFT"])

        size_cols = [c for c in ["A_WORLD", "B_WORLD", "A_IMAGE", "B_IMAGE", "FLUX_RADIUS"] if c in Bn]
        if size_cols:
            # TODO: convert these to a physical R_G (kpc) instead of the magnitude proxy.
            sys.stderr.write(
                f"Note: size columns {size_cols} are present but not used yet; "
                "R_G is still a magnitude-rank proxy.\n"
            )

        # Copy every needed column out while the files are still open.
        Bd = pd.DataFrame({
            "lens_id": _native(B.data[kB]),
            "ra_deg": _native(B.data[raB]).astype(float),
            "dec_deg": _native(B.data[deB]).astype(float),
            "z_lens": _native(B.data[zB]).astype(float),
            "MAG_AUTO_CALIB": _native(B.data[magB]).astype(float),
        })
        Ld = pd.DataFrame({
            "lens_id": _native(L.data[kL]),
            "Mstar_log10": _native(L.data[mL]).astype(float),
            "REDSHIFT_L": _native(L.data[zL]).astype(float) if zL else np.nan,
        })
    return Bd, Ld


def main():
    """Build the KiDS lens table (Bright + LePhare) and write it to ``--out``.

    Steps:
    1. Join the catalogues on ``ID``.
    2. Take RA/Dec, photo-z and MAG_AUTO_CALIB from Bright, and log10 M* from
       LePhare. LePhare ``REDSHIFT`` fills non-finite Bright photo-z.
    3. Bin M* on ``MS_EDGES``.
    4. Assign R_G bins from a magnitude-rank proxy.
    5. Write the CSV and ``outputs/bin_edges.json``.

    Exits with status 1 when a required column is missing.
    """
    import argparse
    ap = argparse.ArgumentParser(description="Build KiDS lenses (Bright+LePhare) → data/lenses.csv")
    ap.add_argument("--bright", required=True)
    ap.add_argument("--lephare", required=True)
    ap.add_argument("--out", default="data/lenses.csv")
    ap.add_argument("--max-rows", type=int, default=None)
    args = ap.parse_args()

    Bd, Ld = _load_tables(args.bright, args.lephare)
    D = Bd.merge(Ld, on="lens_id", how="inner")

    # Prefer Bright photo-z; if NaN, fall back to LePhare REDSHIFT
    D["z_lens"] = np.where(np.isfinite(D["z_lens"]), D["z_lens"], D["REDSHIFT_L"])

    # Clean essentials before binning
    D.replace([np.inf, -np.inf], np.nan, inplace=True)
    D = D.dropna(subset=["ra_deg", "dec_deg", "z_lens", "Mstar_log10"])
    # A lens without a finite magnitude cannot be ranked for the R_G proxy.
    no_mag = D["MAG_AUTO_CALIB"].isna()
    if no_mag.any():
        sys.stderr.write(f"Dropping {int(no_mag.sum())} lenses without a finite MAG_AUTO_CALIB.\n")
        D = D.loc[~no_mag]
    if args.max_rows is not None:  # `is not None` so that --max-rows 0 is honoured
        D = D.head(args.max_rows).copy()

    # Bin M* (lower edge inclusive, upper edge exclusive); drop lenses outside MS_EDGES.
    n_ms = len(MS_EDGES) - 1
    ms_idx = np.digitize(D["Mstar_log10"].to_numpy(), MS_EDGES, right=False) - 1
    in_range = (ms_idx >= 0) & (ms_idx < n_ms)
    D = D.loc[in_range].copy()
    D["Mstar_bin"] = [label(MS_EDGES, i) for i in ms_idx[in_range]]

    # PROXY (clearly marked): within each M* bin, split lenses into equal-count
    # MAG_AUTO_CALIB rank bins (faintest -> smallest R_G bin, brightest -> largest)
    # and set R_G_kpc to the bin midpoint. This lets T3 run and certify plateaus;
    # for the *real* size–amplitude claim, replace it with a true size column.
    n_rg = len(RG_EDGES) - 1
    rg_labels = [label(RG_EDGES, i) for i in range(n_rg)]
    rg_idx = _proxy_rg_index(D["MAG_AUTO_CALIB"], D["Mstar_bin"], n_rg)
    D["R_G_bin"] = [rg_labels[i] for i in rg_idx]
    D["R_G_kpc"] = [RG_MIDS[i] for i in rg_idx]

    # Output with exact headers
    out = D[["lens_id", "ra_deg", "dec_deg", "z_lens", "R_G_kpc", "Mstar_log10", "R_G_bin", "Mstar_bin"]].copy()
    out_path = pathlib.Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out.to_csv(out_path, index=False)

    pathlib.Path("outputs").mkdir(exist_ok=True)
    with open("outputs/bin_edges.json", "w", encoding="utf-8") as fh:
        json.dump({"R_G_edges_kpc": RG_EDGES, "Mstar_edges": MS_EDGES,
                   "note": "R_G_kpc is a MAG_AUTO_CALIB rank-based proxy (bin midpoints). Replace with true sizes for science claim."},
                  fh)
    print(f"Wrote {args.out} with {len(out)} rows.")
    if not out.empty:
        print(out.groupby(["R_G_bin", "Mstar_bin"]).size().to_string())

if __name__ == "__main__":
    # Run the CLI only when executed directly, never on import.
    main()
